#!/usr/bin/python3
import sys
import re

# part 1 (the rock shapes and simulation helpers below are shared with part 2)

# Cells occupied by each rock shape, as (x, y) offsets from the rock's
# bottom-left corner with y growing upward. Shapes are listed in the order
# in which they fall.
solid = [
    [(0, 0), (1, 0), (2, 0), (3, 0)],          # horizontal bar
    [(1, 0), (0, 1), (1, 1), (2, 1), (1, 2)],  # plus
    [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)],  # reversed L
    [(0, 0), (0, 1), (0, 2), (0, 3)],          # vertical bar
    [(0, 0), (0, 1), (1, 0), (1, 1)],          # square
]

# (x, y) offset of the right-most column and top-most row of each shape.
# Derived from ``solid`` so the two tables can never disagree.
rocks = [
    (max(x for x, _ in cells), max(y for _, y in cells))
    for cells in solid
]

def jet_movement_possible(grid, rock, lx, ly, rx, ry):
    """Tell whether rock *rock* may occupy the candidate position (lx, ly).

    (lx, ly) is the candidate bottom-left corner and rx the candidate
    right-most column. The move is allowed only if the rock overlaps no
    settled cell and stays inside the 7-wide chamber. ``ry`` is accepted
    for symmetry with the other coordinates but is not consulted.
    """
    if collision(grid, rock, lx, ly):
        return False
    return lx >= 0 and rx < 7

def collision(grid, rock, lx, ly):
    """Return True if rock shape *rock*, with its bottom-left corner at
    (lx, ly), overlaps any cell already present in *grid*."""
    return any((lx + dx, ly + dy) in grid for dx, dy in solid[rock])

def print_grid(grid):
    """Print *grid* as ASCII art, highest row first.

    Occupied cells are drawn as '#' and empty ones as '.', over the 7
    columns of the chamber and the rows from the highest occupied y down to
    y == 0. An empty grid prints nothing (previously ``max([])`` raised).
    """
    top = max((y for _, y in grid), default=-1)
    for y in range(top, -1, -1):
        # Build each row once instead of printing it character by character.
        print(''.join('#' if (x, y) in grid else '.' for x in range(7)))

def part1(rock_count, jet=None, *, verbose=False):
    """Simulate *rock_count* falling rocks and return the tower height.

    The chamber is 7 units wide with a floor at y == 0. Each rock spawns
    with its left edge at x == 2 and its bottom edge at tower height + 4
    (three empty rows above the top). It then alternates a jet push with a
    one-unit fall until the fall would collide, and comes to rest there.

    :param rock_count: number of rocks to drop.
    :param jet: jet pattern of '<' / '>' characters; when falsy, the first
        line of stdin is read instead.
    :param verbose: when true, also dump the settled cells and an ASCII
        rendering of the chamber. This output is very long for large runs.
    :returns: the tower height after the last rock settles (also printed).
    """
    if not jet:
        jet = next(sys.stdin).strip()
    jetx = [-1 if c == '<' else 1 for c in jet]  # any non-'<' pushes right
    jetlen = len(jetx)
    jetindex = 0

    rock = 0
    rocklen = len(rocks)
    grid = {(x, 0) for x in range(7)}  # the floor
    hx, hy = 2, 4  # spawn corner; hy is always tower height + 4
    for _ in range(rock_count):
        rock_w, rock_h = rocks[rock]
        lx, ly = hx, hy
        rx, ry = lx + rock_w, ly + rock_h
        # The check at the spawn point always passes. Each iteration then
        # applies one jet push and one fall, and the loop stops as soon as
        # the fallen position overlaps the grid.
        while not collision(grid, rock, lx, ly):
            jx = jetx[jetindex]
            if jet_movement_possible(grid, rock, lx + jx, ly, rx + jx, ry):
                lx += jx
                rx += jx
            jetindex = (jetindex + 1) % jetlen
            ly -= 1
            ry -= 1
        # The last fall collided, so the rock rests one unit higher.
        for sx, sy in solid[rock]:
            grid.add((sx + lx, sy + ly + 1))
        # hy - 4 is the previous tower height; ry + 1 is this rock's top row.
        hy = max(ry + 1, hy - 4) + 4
        rock = (rock + 1) % rocklen

    if verbose:
        print(sorted(grid))
        print_grid(grid)
    print(f'tower height is {hy - 4}')
    return hy - 4

def find_cycle(diffs):
    """Return the length of the shortest cycle repeated at the end of *diffs*.

    A period ``p`` matches when the last ``p`` items equal the ``p`` items
    immediately before them. Returns -1 when no period matches, which
    includes every sequence shorter than 2.

    >>> find_cycle([7, 1, 2, 3, 1, 2, 3])
    3
    >>> find_cycle([1, 2, 3, 1])
    -1
    """
    n = len(diffs)
    for period in range(1, n // 2 + 1):
        if diffs[n - period:] == diffs[n - 2 * period:n - period]:
            return period
    return -1

# Number of top rows of the tower included in the cycle-detection state.
# Deep enough that two equal states are expected to evolve identically; this
# is a heuristic, not a proof.
_SURFACE_ROWS = 32


def _drop_rock(grid, rock, jetx, jetindex, height):
    """Drop rock shape *rock* onto *grid* (mutated in place).

    The rock spawns with its left edge at x == 2 and its bottom edge at
    height + 4. It alternates one jet push with one fall until the fall
    would collide. Returns the next jet index and the new tower height.
    """
    rock_w, rock_h = rocks[rock]
    lx, ly = 2, height + 4
    while True:
        jx = jetx[jetindex]
        jetindex = (jetindex + 1) % len(jetx)
        if jet_movement_possible(grid, rock, lx + jx, ly,
                                 lx + jx + rock_w, ly + rock_h):
            lx += jx
        if collision(grid, rock, lx, ly - 1):
            break
        ly -= 1
    for sx, sy in solid[rock]:
        grid.add((lx + sx, ly + sy))
    return jetindex, max(height, ly + rock_h)


def _surface(grid, height, rows):
    """Return the occupied cells of the top *rows* rows as (x, depth) pairs,
    where depth is measured down from *height*."""
    return frozenset(
        (x, height - y)
        for y in range(height - rows + 1, height + 1)
        for x in range(7)
        if (x, y) in grid
    )


def part2(rock_count, jet=None):
    """Return the tower height after *rock_count* rocks, using cycle detection.

    Simulating 10**12 rocks directly is infeasible. After every rock, a
    state is recorded: the next rock shape, the next jet index, and the
    shape of the top ``_SURFACE_ROWS`` rows. When a state repeats, the rocks
    in between form a cycle that adds a fixed height. Whole cycles are
    skipped arithmetically, and the remaining rocks are simulated from the
    current (equivalent) state.

    :param rock_count: number of rocks to drop.
    :param jet: jet pattern; when falsy, the first line of stdin is read.
    :returns: the tower height (also printed).
    """
    if not jet:
        jet = next(sys.stdin).strip()
    jetx = [-1 if c == '<' else 1 for c in jet]

    grid = {(x, 0) for x in range(7)}  # the floor
    height = 0
    rock = 0
    jetindex = 0
    dropped = 0
    skipped_height = 0
    seen = {}  # state -> (rocks dropped, tower height) when first reached
    cycle_found = False

    while dropped < rock_count:
        jetindex, height = _drop_rock(grid, rock, jetx, jetindex, height)
        rock = (rock + 1) % len(rocks)
        dropped += 1
        if cycle_found:
            continue
        state = (rock, jetindex, _surface(grid, height, _SURFACE_ROWS))
        if state in seen:
            prev_dropped, prev_height = seen[state]
            period = dropped - prev_dropped
            cycles = (rock_count - dropped) // period
            # Jump over whole cycles without simulating them. The grid keeps
            # its real height, and the skipped gain is added at the end.
            dropped += cycles * period
            skipped_height = cycles * (height - prev_height)
            cycle_found = True
        else:
            seen[state] = (dropped, height)

    total = height + skipped_height
    print(f'tower height is {total}')
    return total

if __name__ == '__main__':
    # The first argument selects the puzzle part; the jet pattern is read
    # from stdin. An exact comparison is used because the old substring
    # test (`in '1'`) also accepted an empty argument.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} 1|2  (jet pattern on stdin)')
    if sys.argv[1] == '1':
        part1(2022)
    else:
        part2(1_000_000_000_000)
